{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "bdaa9a5f-e29b-4df1-a9da-30b18999f49d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at D:/gitee/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import BertTokenizer, BertForMaskedLM\n",
    "import pandas as pd \n",
    "import numpy as np \n",
    "# Load the pretrained BERT masked-LM model and its tokenizer from a local checkpoint.\n",
    "# NOTE(review): this is an absolute, machine-specific path; adjust MODEL_DIR for your environment.\n",
    "MODEL_DIR = \"D:/gitee/bert-base-uncased\"\n",
    "tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)\n",
    "model = BertForMaskedLM.from_pretrained(MODEL_DIR)\n",
    "\n",
    "# Candidate answers as a DataFrame: one row per question, columns named A, B, C, D and RT.\n",
    "options = pd.read_table(\"初中完型填空对应的选项.txt\", header=None, sep=\",\")\n",
    "options.columns = [\"A\", \"B\", \"C\", \"D\", \"RT\"]\n",
    "\n",
    "# Cloze sentences, one per line of the text file.\n",
    "with open(\"初中完型填空-Copy1.txt\", \"r\") as file:\n",
    "    # Drop trailing newlines before splitting: split(\"\\n\") would otherwise yield a\n",
    "    # final empty sentence (with no matching row in `options`) when the file ends\n",
    "    # with a newline.\n",
    "    tidy_data = file.read().rstrip(\"\\n\").split(\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7bba9583-8be6-45d1-a140-79363c7e9595",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['crazy', 'beautiful', 'lazy', 'ugly']\n"
     ]
    }
   ],
   "source": [
    "# Preview the candidate options stored in the first four columns of the first row.\n",
    "words = [option for option in options.iloc[0, :4]]\n",
    "print(words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fe33b2be-d08a-416c-9e18-58ce05ca8811",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word: crazy, Probability: 15.384073257446289\n",
      "Word: beautiful, Probability: 14.479395866394043\n",
      "Word: lazy, Probability: 13.185140609741211\n",
      "Word: ugly, Probability: 14.475358963012695\n",
      "Word: think, Probability: 19.187870025634766\n",
      "Word: hope, Probability: 16.129600524902344\n",
      "Word: suggest, Probability: 12.45032787322998\n",
      "Word: ask, Probability: 18.128929138183594\n",
      "Word: wrong, Probability: 19.49198341369629\n",
      "Word: welcome, Probability: 8.906497955322266\n",
      "Word: important, Probability: 13.835725784301758\n",
      "Word: interesting, Probability: 13.532438278198242\n",
      "Word: lessons, Probability: 17.660797119140625\n",
      "Word: symbols, Probability: 9.39571475982666\n",
      "Word: novels, Probability: 12.453939437866211\n",
      "Word: results, Probability: 18.832393646240234\n",
      "Word: stupid, Probability: 15.29840087890625\n",
      "Word: interested, Probability: 17.857742309570312\n",
      "Word: sad, Probability: 20.25552749633789\n",
      "Word: glad, Probability: 14.4566011428833\n",
      "Word: although, Probability: 16.11528205871582\n",
      "Word: if, Probability: 18.08999252319336\n",
      "Word: before, Probability: 19.832168579101562\n",
      "Word: even though, Probability: -2.272726535797119\n",
      "Word: sees, Probability: 15.9879150390625\n",
      "Word: likes, Probability: 18.262699127197266\n",
      "Word: wakes, Probability: 14.946699142456055\n",
      "Word: catches, Probability: 18.267948150634766\n",
      "Word: suddenly, Probability: 11.924894332885742\n",
      "Word: early, Probability: 15.743229866027832\n",
      "Word: recently, Probability: 11.323549270629883\n",
      "Word: together, Probability: 15.698836326599121\n",
      "Word: better, Probability: 18.09197425842285\n",
      "Word: angrier, Probability: -0.7323864698410034\n",
      "Word: busier, Probability: -2.643861770629883\n",
      "Word: heavier, Probability: 14.651044845581055\n",
      "Word: purpose, Probability: 10.684463500976562\n",
      "Word: opinion, Probability: 17.162281036376953\n",
      "Word: friendship, Probability: 19.558452606201172\n",
      "Word: habit, Probability: 9.799904823303223\n"
     ]
    }
   ],
   "source": [
    "pre_word = []  # predicted option text for each sentence, in sentence order\n",
    "\n",
    "for i, sentence in enumerate(tidy_data):\n",
    "    # Encode the still-masked sentence once to find where the blank sits.\n",
    "    masked_ids = tokenizer.encode(sentence, add_special_tokens=True, return_tensors=\"pt\")\n",
    "    mask_index = torch.where(masked_ids == tokenizer.mask_token_id)[1]\n",
    "    if mask_index.numel() != 1:\n",
    "        # The scoring below reads a single position; fail with a clear message\n",
    "        # instead of an opaque tensor-to-scalar conversion error.\n",
    "        raise ValueError(\n",
    "            f\"Sentence {i} must contain exactly one [MASK] token, found {mask_index.numel()}\"\n",
    "        )\n",
    "    start = mask_index.item()\n",
    "\n",
    "    prob = []  # one score per candidate; the highest-scoring candidate is the prediction\n",
    "    words = list(options.iloc[i, :4])\n",
    "    for word in words:\n",
    "        # WordPiece ids of the candidate. A multi-word or out-of-vocabulary option\n",
    "        # (e.g. \"even though\") spans several tokens; looking the whole string up as\n",
    "        # a single vocabulary entry would silently score the unknown token instead.\n",
    "        piece_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))\n",
    "\n",
    "        # Fill the blank with the candidate and encode the completed sentence.\n",
    "        filled_sentence = sentence.replace(\"[MASK]\", word)\n",
    "        input_ids = tokenizer.encode(filled_sentence, add_special_tokens=True, return_tensors=\"pt\")\n",
    "\n",
    "        # Inference only: skip gradient tracking.\n",
    "        with torch.no_grad():\n",
    "            predictions = model(input_ids).logits\n",
    "\n",
    "        # Score = mean logit of the candidate's tokens at the positions they occupy,\n",
    "        # starting at the former [MASK] position. For a single-token candidate this\n",
    "        # is exactly the logit of that token at the blank.\n",
    "        positions = torch.arange(start, start + len(piece_ids))\n",
    "        predicted_probability = predictions[0, positions, torch.tensor(piece_ids)].mean().item()\n",
    "\n",
    "        # NOTE: the printed value is a raw logit, not a normalized probability.\n",
    "        print(f\"Word: {word}, Probability: {predicted_probability}\")\n",
    "        prob.append(predicted_probability)\n",
    "    pre_word.append(options.iloc[i, int(np.argmax(prob))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab20a48a-0fd7-4679-a603-1be056ae3a9c",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'options' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20224\\3937129061.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mright_word\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0moptions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0moptions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_20224\\3937129061.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mright_word\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0moptions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0moptions\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m: name 'options' is not defined"
     ]
    }
   ],
   "source": [
    "# Correct answer text for every question: the value in column 4 of each row is used\n",
    "# as the column label to read from that row. Covers all rows of `options` instead\n",
    "# of a hard-coded 10.\n",
    "right_word = [options.loc[i, options.iloc[i, 4]] for i in range(len(options))]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20dd0708",
   "metadata": {},
   "source": [
    "预测结果与正确选项对比(每行格式:模型预测 --- 正确答案)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9456a942-9caa-4c22-877a-27ad6e43e36c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['crazy --- beautiful',\n",
       " 'think --- think',\n",
       " 'wrong --- wrong',\n",
       " 'results --- results',\n",
       " 'sad --- sad',\n",
       " 'before --- if',\n",
       " 'catches --- likes',\n",
       " 'early --- together',\n",
       " 'better --- better',\n",
       " 'friendship --- friendship']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair each prediction with the correct answer as \"prediction --- answer\".\n",
    "# zip() covers however many questions were scored instead of a hard-coded 10.\n",
    "[f\"{predicted} --- {correct}\" for predicted, correct in zip(pre_word, right_word)]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
